import sys, numpy as np
from OpenGL.GL import *
from OpenGL.GLUT import *
from OpenGL.GL.shaders import compileProgram, compileShader
import math

# ---------------- Shader Sources ----------------
# Vertex stage: passes the fullscreen quad through unchanged and maps clip-space
# position [-1, 1] to texture coordinates [0, 1].
VERTEX_SRC = """
#version 330
layout (location = 0) in vec2 pos;
out vec2 texCoord;
void main() {
    texCoord = (pos + 1.0) * 0.5;
    gl_Position = vec4(pos, 0.0, 1.0);
}
"""

# Fragment stage: looks up one lattice cell per fragment, applies the toy
# update, and outputs a black/white value. The grid dimensions are read from
# the bound lattice texture, so they cannot drift out of sync with init_gl.
FRAGMENT_SRC = """
#version 330
in vec2 texCoord;
out vec4 fragColor;

uniform sampler2D latticeTex;   // lattice: width = slots, height = instances (32x72 from init_gl)
uniform sampler2D historyTex;   // scrolling history (declared, not sampled yet)
uniform float cycle;            // frame counter; only its integer parity is used
uniform float omegaTime;        // analog Omega overlay phase

// Threshold function: 1.0 when x is strictly above 0.5, else 0.0.
float binarize(float x) {
    return x > 0.5 ? 1.0 : 0.0;
}

void main() {
    // Derive the grid size from the texture instead of hard-coding 32x72.
    ivec2 gridSize = textureSize(latticeTex, 0);

    // Clamp so texelFetch never reads outside the texture (undefined result).
    int x = clamp(int(texCoord.x * float(gridSize.x)), 0, gridSize.x - 1);  // lattice slot
    int y = clamp(int(texCoord.y * float(gridSize.y)), 0, gridSize.y - 1);  // lattice instance

    // fetch current slot state
    float val = texelFetch(latticeTex, ivec2(x, y), 0).r;

    // --- HDGL toy update ---
    float r_dim = 0.3 + 0.01 * float(y);        // per-instance recursion bias
    float omega = 0.5 + 0.5 * sin(omegaTime);   // analog Omega oscillation in [0, 1]
    float new_val = val + r_dim * 0.5 + omega * 0.25;

    float slot = binarize(mod(new_val + float((int(cycle) + x + y) % 2), 2.0));

    fragColor = vec4(slot, slot, slot, 1.0);
}
"""

# ---------------- Globals ----------------
# Module-level state shared by main(), init_gl() and the GLUT callbacks.
window = None                     # GLUT window id returned by glutCreateWindow
shader = vao = None               # linked shader program / fullscreen-quad VAO
lattice_tex = history_tex = None  # GL texture names created in init_gl
cycle = 0.0                       # frame counter uploaded as the `cycle` uniform
omega_time = 0.0                  # phase uploaded as the `omegaTime` uniform

# ---------------- OpenGL Init ----------------
def _create_r32f_texture(width, height, data):
    """Create a single-channel float (GL_R32F) 2D texture and return its name.

    Args:
        width: Texture width in texels.
        height: Texture height in texels.
        data: Row-major float32 array of shape (height, width), or None to
            allocate uninitialised storage.

    The texture is left bound to GL_TEXTURE_2D on the active texture unit.
    """
    tex = glGenTextures(1)
    glBindTexture(GL_TEXTURE_2D, tex)
    glTexImage2D(GL_TEXTURE_2D, 0, GL_R32F, width, height, 0, GL_RED, GL_FLOAT, data)
    # Only level 0 is defined. The default MIN_FILTER uses mipmaps and would
    # leave the texture incomplete, so force NEAREST filtering.
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST)
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST)
    return tex


def init_gl():
    """Create the shader program, fullscreen-quad VAO and the state textures.

    Requires a current GL context; main() calls it right after
    glutCreateWindow. Sets the module globals shader, vao, lattice_tex and
    history_tex, which display() reads.
    """
    global shader, vao, lattice_tex, history_tex

    slots, instances = 32, 72  # lattice texture width x height
    history_rows = 100         # arbitrary history depth (cycles)

    # Bind a VAO *before* compileProgram. PyOpenGL validates the program after
    # linking, and core-profile drivers (e.g. macOS) reject validation when no
    # VAO is bound.
    vao = glGenVertexArrays(1)
    glBindVertexArray(vao)

    shader = compileProgram(
        compileShader(VERTEX_SRC, GL_VERTEX_SHADER),
        compileShader(FRAGMENT_SRC, GL_FRAGMENT_SHADER),
    )

    # Fullscreen quad: two triangles covering clip space [-1, 1]^2.
    verts = np.array([-1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, 1], dtype=np.float32)
    vbo = glGenBuffers(1)
    glBindBuffer(GL_ARRAY_BUFFER, vbo)
    glBufferData(GL_ARRAY_BUFFER, verts.nbytes, verts, GL_STATIC_DRAW)
    # Attribute 0 feeds `pos` in the vertex shader: two tightly packed floats.
    glVertexAttribPointer(0, 2, GL_FLOAT, GL_FALSE, 0, None)
    glEnableVertexAttribArray(0)

    # Lattice state: one row per instance, one column per slot.
    # Seed slot 0 of every instance as active.
    init_lattice = np.zeros((instances, slots), dtype=np.float32)
    init_lattice[:, 0] = 1.0
    lattice_tex = _create_r32f_texture(slots, instances, init_lattice)

    # History texture: same width as the lattice, storage left uninitialised.
    history_tex = _create_r32f_texture(slots, history_rows, None)

# ---------------- Display ----------------
def display():
    """GLUT display callback: draw one lattice frame and advance the clocks.

    Binds latticeTex to texture unit 0 and historyTex to unit 1, draws the
    fullscreen quad, swaps buffers, then increments `cycle` by one and
    advances `omega_time` by 0.05 rad (wrapped to [0, 2*pi)).
    """
    global cycle, omega_time
    glClear(GL_COLOR_BUFFER_BIT)

    glUseProgram(shader)
    glUniform1i(glGetUniformLocation(shader, "latticeTex"), 0)
    glUniform1i(glGetUniformLocation(shader, "historyTex"), 1)
    # Uniforms are float32 on the GPU, so integers above 2**24 are no longer
    # all representable and the shader's int(cycle) parity would freeze. The
    # shader only uses that parity, which an even modulus preserves.
    glUniform1f(glGetUniformLocation(shader, "cycle"), cycle % 65536.0)
    glUniform1f(glGetUniformLocation(shader, "omegaTime"), omega_time)

    glActiveTexture(GL_TEXTURE0)
    glBindTexture(GL_TEXTURE_2D, lattice_tex)

    glActiveTexture(GL_TEXTURE1)
    glBindTexture(GL_TEXTURE_2D, history_tex)

    glBindVertexArray(vao)
    glDrawArrays(GL_TRIANGLES, 0, 6)  # two triangles = fullscreen quad

    glutSwapBuffers()
    cycle += 1
    # Smooth Omega oscillation. The shader only takes sin() of this phase, so
    # wrapping it to [0, 2*pi) is lossless and avoids float32 precision decay
    # over long runs.
    omega_time = (omega_time + 0.05) % (2.0 * math.pi)

def idle():
    """GLUT idle callback: request another redraw whenever the loop is idle."""
    # Marking the window for redisplay makes GLUT invoke the registered
    # display callback again, which keeps the animation running.
    glutPostRedisplay()

# ---------------- Main ----------------
def main():
    """Create the GLUT window, initialise GL state and enter the event loop."""
    global window
    glutInit(sys.argv)
    glutInitDisplayMode(GLUT_RGBA | GLUT_DOUBLE)
    # The shaders use '#version 330' and the renderer relies on a VAO, so ask
    # for a 3.3 core context where the GLUT implementation supports it
    # (freeglut). Otherwise some drivers hand out a legacy context that cannot
    # compile the shaders. Unavailable PyOpenGL functions evaluate as false,
    # so other GLUT implementations simply skip the request.
    if glutInitContextVersion and glutInitContextProfile:
        glutInitContextVersion(3, 3)
        glutInitContextProfile(GLUT_CORE_PROFILE)
    glutInitWindowSize(1280, 720)
    window = glutCreateWindow(b"HDGL GPU Lattice w/ PHI Clock")
    init_gl()  # needs the context created by glutCreateWindow
    glutDisplayFunc(display)
    glutIdleFunc(idle)
    glutMainLoop()

if __name__ == "__main__":
    # Launch the demo only when run as a script, not when imported.
    main()
